import torch
import torchvision.models as models
from torch import nn
from torchvision.models import ResNet18_Weights  # 导入 ResNet18_Weights

class FontModel(nn.Module):
    """Font classifier built on a torchvision ResNet-18 backbone.

    The backbone's original fully connected head is replaced by a dropout
    layer followed by a fresh linear layer producing ``num_classes`` outputs.
    """

    def __init__(
        self,
        num_classes: int = 3,
        weights=ResNet18_Weights.DEFAULT,
        *,
        dropout: float = 0.5,
    ):
        """Create the backbone and swap in a new classification head.

        Args:
            num_classes: Number of outputs of the new linear layer.
            weights: Forwarded unchanged to ``torchvision.models.resnet18``.
                Defaults to ``ResNet18_Weights.DEFAULT`` (pretrained weights,
                which torchvision may download on first use); pass ``None``
                for a randomly initialized backbone.
            dropout: Drop probability of the dropout layer placed in front of
                the new linear layer, used to reduce overfitting. Keyword-only;
                defaults to 0.5, the value previously hard-coded here.
                ``nn.Dropout`` rejects values outside [0, 1] with ValueError.
        """
        super().__init__()
        self.backbone = models.resnet18(weights=weights)
        # Read the width of the original head's input so the replacement
        # linear layer matches the backbone's pooled feature size.
        in_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Sequential(
            nn.Dropout(dropout),  # regularization against overfitting
            nn.Linear(in_features, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the backbone (including the replaced head) on ``x``.

        Args:
            x: Input batch. Presumably a float image tensor of shape
                (N, 3, H, W), as torchvision's ResNet-18 expects. TODO: confirm
                against the data pipeline.

        Returns:
            Raw, unnormalized scores whose last dimension is ``num_classes``.
            No softmax is applied here.
        """
        return self.backbone(x)